from evaluation_utils import normalize_final_answer, remove_boxed, last_boxed_only_string
import re
# Fallback pattern used when a response has no usable \boxed{...} answer.
# The alternation is grouped so the optional sign applies to integers as well
# as decimals; ungrouped, "|" binds loosest and "-5" would be read as "5".
_NUMBER_PATTERN = re.compile(r"[-+]?(?:\d*\.\d+|\d+)")


def _read_lines(path):
    """Return every line of *path* (line endings kept), closing the file afterwards."""
    with open(path, "r") as handle:
        return handle.readlines()


def _last_response(line, prompt_marker):
    """Return the text after the last *prompt_marker* in the second-to-last
    "<|endoftext|>"-delimited segment of *line*.

    Raises IndexError when *line* contains no "<|endoftext|>" separator.
    """
    return line.split("<|endoftext|>")[-2].split(prompt_marker)[-1]


def _extract_answer(response):
    """Return the answer string extracted from one model response.

    The normalized content of the last boxed expression is used when it is
    non-empty; otherwise the last number found in the text; otherwise "".
    """
    boxed = normalize_final_answer(remove_boxed(last_boxed_only_string(response)))
    if boxed != "":
        return boxed
    numbers = _NUMBER_PATTERN.findall(response)
    return numbers[-1] if numbers else ""


# One reference answer per line; trailing whitespace (including the newline)
# is stripped so the exact-match comparison below is not thrown off by it.
groundtruth = [line.rstrip() for line in _read_lines("math_groundtruth.txt")]

# Each output line holds one generation. The markers contain a literal
# backslash followed by "n" (not a newline character), matching how the
# generations are stored one per line.
normallines = [
    _last_response(line, "\\n<|assistant|>\\n")
    for line in _read_lines("normal_math_outputs.txt")
]
perturbedlines = [
    _last_response(line, "<|end_header_id|>\\n\\n")
    for line in _read_lines("perturbed_math_outputs.txt")
]

# Reduce every raw response to its extracted answer string.
normallines = [_extract_answer(response) for response in normallines]
perturbedlines = [_extract_answer(response) for response in perturbedlines]

# Fail early with a clear message instead of a ZeroDivisionError at the end
# or an IndexError part-way through the scoring loop.
if not groundtruth:
    raise ValueError("math_groundtruth.txt contains no answers to score against")
if len(normallines) < len(groundtruth) or len(perturbedlines) < len(groundtruth):
    raise ValueError(
        "output files have fewer lines than math_groundtruth.txt "
        f"(groundtruth={len(groundtruth)}, normal={len(normallines)}, "
        f"perturbed={len(perturbedlines)})"
    )

normalcorrect = 0
perturbedcorrect = 0
# Exact string match against the reference. Outputs are paired with answers by
# line position; any output lines beyond the ground-truth length are ignored.
for answer, normal_answer, perturbed_answer in zip(groundtruth, normallines, perturbedlines):
    print(perturbed_answer)
    print("--------------------------------------")
    if normal_answer == answer:
        normalcorrect += 1
    if perturbed_answer == answer:
        perturbedcorrect += 1

print("Normal:", normalcorrect / len(groundtruth))
print("Perturbed:", perturbedcorrect / len(groundtruth))